
feat: configurable KVCache step size and pre-allocation#1038

Closed
Thump604 wants to merge 1 commit into ml-explore:main from Thump604:feat/kvcache-preallocation

Conversation

@Thump604

Summary

Add optional step and max_size parameters to KVCache (and step to QuantizedKVCache) to reduce memory churn during generation.

Problem

KVCache grows its buffer in fixed 256-token increments. When the buffer fills, it allocates a new chunk and concatenates it with the existing buffer. For a model with 12 attention layers, each boundary crossing creates 24 concatenation operations (12 layers × K+V). During concatenation, both old and new buffers are live simultaneously, causing a transient memory spike.

On M2 Ultra 128GB with a 122B MoE model (~82 GB weights), these spikes contribute to Metal OOM on long generations (4000+ tokens). The repeated allocate-concatenate-free cycle also creates GPU sync points that reduce pipeline efficiency.
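To make the churn concrete, here is a back-of-the-envelope model of the boundary crossings (plain Python, independent of mlx-lm; the numbers mirror the 12-layer, 4000-token scenario above):

```python
# Hypothetical model of the growth pattern described above: count how many
# boundary crossings (and concatenations) a generation incurs for a given
# allocation step. `layers` and `tokens` are the illustrative numbers above.

def boundary_crossings(tokens: int, step: int) -> int:
    """Number of times the buffer fills and a new chunk must be
    concatenated with the existing one. The first chunk has nothing
    to concatenate with, so it is not counted."""
    return (tokens - 1) // step if tokens > 0 else 0

layers = 12
tokens = 4000

for step in (256, 1024):
    crossings = boundary_crossings(tokens, step)
    concats = crossings * layers * 2  # K and V buffers per layer
    print(f"step={step}: {crossings} crossings, {concats} concatenations")
```

With step=256 this reproduces the 15 boundary crossings (360 concatenations) cited below; step=1024 drops that to 3 crossings.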

Changes

KVCache.__init__(max_size=None, step=None):

  • step: Override the 256-token allocation increment. Larger values (e.g., 1024) reduce boundary crossings 4×.
  • max_size: Pre-allocate the full buffer on first use. Eliminates ALL subsequent reallocations and concatenations. Ideal when max context is known from server configuration.

QuantizedKVCache.__init__(..., step=None):

  • Same step override for quantized caches.

make_prompt_cache(model, max_kv_size=None, max_context=None):

  • New max_context parameter passes through to KVCache(max_size=max_context).

All parameters are optional with backward-compatible defaults. Existing code calling KVCache() is unaffected.
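A framework-free sketch of the growth policy these parameters describe (`SketchCache` is a hypothetical stand-in that only tracks capacity, not the actual mlx-lm KVCache, which operates on mlx arrays):

```python
# Simplified model of the proposed allocation policy: `step` rounds capacity
# up to the next step boundary; `max_size` allocates the full buffer once.

class SketchCache:
    def __init__(self, step=None, max_size=None):
        self.step = step or 256   # proposed override of the 256-token default
        self.max_size = max_size  # proposed one-shot pre-allocation size
        self.capacity = 0
        self.offset = 0

    def reserve(self, n_new: int) -> bool:
        """Grow capacity to fit n_new more tokens; return True if a
        (re)allocation was needed."""
        needed = self.offset + n_new
        if needed <= self.capacity:
            self.offset = needed
            return False
        if self.max_size is not None:
            # Pre-allocate the full buffer up front; no further growth
            # unless the generation exceeds max_size.
            self.capacity = max(self.max_size, needed)
        else:
            # Round up to the next step boundary, as the cache does today.
            self.capacity = ((needed + self.step - 1) // self.step) * self.step
        self.offset = needed
        return True

cache = SketchCache(max_size=4096)
reallocs = sum(cache.reserve(1) for _ in range(4000))
print(reallocs)  # 1: a single allocation up front, then no growth
```

Token-by-token generation against this sketch shows the tradeoff directly: `max_size` pays one allocation total, while `step`-based growth pays one per boundary.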

Example usage

```python
from mlx_lm.models.cache import KVCache, make_prompt_cache

# Server knows max context is 128K — pre-allocate, zero reallocs
cache = [KVCache(max_size=131072) for _ in model.layers]

# Or via make_prompt_cache
cache = make_prompt_cache(model, max_context=131072)

# Or just increase the step size
cache = [KVCache(step=1024) for _ in model.layers]
```

Test plan

  • Existing tests pass unchanged (backward compatible defaults)
  • KVCache() behaves identically to before
  • KVCache(max_size=1024) pre-allocates on first update, no subsequent reallocs
  • KVCache(step=1024) reallocates every 1024 tokens instead of 256
  • Memory profile shows flat allocation with max_size vs sawtooth without

Add optional parameters to KVCache and QuantizedKVCache:

- `step` (int): Override the class-level step size (default 256).
  Larger values reduce the number of boundary reallocations and
  GPU sync points during generation. For example, step=1024 reduces
  reallocation frequency 4x.

- `max_size` (int, KVCache only): Pre-allocate the buffer to hold
  this many tokens on first use. Eliminates ALL subsequent boundary
  reallocations and concatenations. When the maximum context length
  is known (e.g., from server configuration), this avoids the
  repeated allocate-concatenate-free cycle that causes transient
  memory spikes and GPU sync points.

Also adds `max_context` parameter to `make_prompt_cache()` to pass
through to KVCache constructors.

All parameters are optional with backward-compatible defaults.
Existing code calling `KVCache()` or `make_prompt_cache(model)` is
unaffected.

Motivation: On M2 Ultra 128GB with a 122B MoE model (~82 GB weights),
the repeated KV cache boundary reallocations (every 256 tokens across
12 attention layers) create transient memory spikes of ~2x the cache
size at each boundary. With step=256, a 4000-token generation crosses
15 boundaries, each requiring 24 concatenation operations (12 layers
× K+V). Pre-allocation eliminates this entirely.
@Thump604
Author

@angeloskath @awni — this PR has had no maintainer review since it was opened. Is there a concern with the approach?

It's a small, backwards-compatible change — adds optional step and max_size parameters to KVCache. Default behavior is unchanged. Reduces memory churn from repeated 256-token buffer concatenations during long generation on large models.

@angeloskath
Member

Sorry for the late review.

We used to have it as a constructor argument; I am not sure it is particularly helpful. The memory issues can be fixed by clearing the cache, which we should be doing diligently now. Counterintuitively, using a large array from the start is not optimal because our reads jump around unnecessarily, which ends up being noticeably slower.

> During concatenation, both old and new buffers are live simultaneously, causing a transient memory spike.

That is not quite true by the way. The old buffers will be freed as we go so the memory needed for these operations is more like N * new_size + old_size instead of N * (new_size + old_size).
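The accounting above can be illustrated with hypothetical per-layer sizes (plain Python; the sizes are invented for illustration, not measured):

```python
# If old per-layer buffers are freed as each new one is built, the peak
# holds N new buffers plus at most one old one, rather than all 2N at once.
# n_layers matches the 12-layer example in the PR; sizes are hypothetical.

n_layers = 12
old_size_gb = 1.0  # per-layer cache before the boundary
new_size_gb = 1.1  # per-layer cache after growing by one step

naive_peak = n_layers * (new_size_gb + old_size_gb)  # all buffers live at once
actual_peak = n_layers * new_size_gb + old_size_gb   # old ones freed as we go

print(f"naive: {naive_peak:.1f} GB, actual: {actual_peak:.1f} GB")
```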

If you have a specific code snippet that is problematic, i.e. causes OOM or is slower than it could be, please file an issue and we can revisit this.

